{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Parameter Learning in Discrete Bayesian Networks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we show an example of learning the parameters (CPDs) of a Discrete Bayesian Network given the data and the model structure. pgmpy has three main methods for learning the parameters:\n", "1. Maximum Likelihood Estimator (pgmpy.estimators.MaximumLikelihoodEstimator)\n", "2. Bayesian Estimator (pgmpy.estimators.BayesianEstimator)\n", "3. Expectation Maximization (pgmpy.estimators.ExpectationMaximization)\n", "\n", "In the examples, we generate data from a given model and then learn the model parameters back from the generated data." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 1: Generate some data" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Generating for node: CVP: 100%|██████████| 37/37 [00:01<00:00, 24.08it/s] \n" ] }, { "data": { "text/html": [ "
| | HISTORY | CVP | PCWP | HYPOVOLEMIA | LVEDVOLUME | LVFAILURE | STROKEVOLUME | ERRLOWOUTPUT | HRBP | HREKG | ... | MINVOLSET | VENTMACH | VENTTUBE | VENTLUNG | VENTALV | ARTCO2 | CATECHOL | HR | CO | BP |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | FALSE | NORMAL | NORMAL | FALSE | NORMAL | FALSE | NORMAL | FALSE | HIGH | HIGH | ... | NORMAL | NORMAL | LOW | ZERO | ZERO | HIGH | HIGH | HIGH | HIGH | HIGH |
| 1 | FALSE | NORMAL | NORMAL | FALSE | NORMAL | FALSE | NORMAL | TRUE | LOW | LOW | ... | NORMAL | NORMAL | LOW | ZERO | ZERO | HIGH | HIGH | LOW | LOW | LOW |
| 2 | FALSE | LOW | LOW | TRUE | LOW | TRUE | LOW | FALSE | HIGH | NORMAL | ... | NORMAL | NORMAL | ZERO | LOW | HIGH | LOW | HIGH | HIGH | LOW | LOW |
| 3 | FALSE | NORMAL | NORMAL | FALSE | NORMAL | FALSE | NORMAL | FALSE | HIGH | HIGH | ... | NORMAL | NORMAL | LOW | ZERO | ZERO | HIGH | HIGH | HIGH | HIGH | HIGH |
| 4 | FALSE | HIGH | HIGH | TRUE | HIGH | FALSE | NORMAL | TRUE | NORMAL | HIGH | ... | NORMAL | NORMAL | ZERO | HIGH | LOW | HIGH | HIGH | HIGH | HIGH | HIGH |

5 rows × 37 columns
\n", "